import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import support_functions as sp
import sys
from scipy.optimize import curve_fit
from scipy.optimize import minimize
from scipy.interpolate import interp1d
from scipy.interpolate import UnivariateSpline
from scipy.optimize import root
from scipy.stats import lognorm
from scipy.stats import beta
from scipy.special import expi
from numpy.linalg import inv
import glob
#import pynverse
import os
import tikzplotlib

# Figure-size tuples, each named after a subplot grid (1x2, 2x1, 2x2).
# NOTE(review): presumably (width, height) in inches for matplotlib figures;
# none of them is referenced in the visible part of this file -- confirm usage.
FIG_SIZE_1x2 = (9,4)
FIG_SIZE_2x1 = (5,8)
FIG_SIZE_2x2 = (10,8)

def philips_slopes_kimball(shares, markups, passthroughs, frisch=0.2, inc_elast=0.2, pi=0.5):
	"""Return six Phillips-curve slopes for a given firm distribution.

	Each channel is first computed as (dlog Y / dlog w) and then inverted, so
	every returned value is the reciprocal of the corresponding channel.

	Args:
		shares: numpy array of firm weights; its ``shape`` sizes the stickiness vector.
		markups: firm markups; the elasticity used is ``1/(1 - 1/markups)``.
		passthroughs: firm pass-throughs.
		frisch: scalar multiplying the markup response (presumably the Frisch elasticity).
		inc_elast: scalar entering through ``1/(1+inc_elast)`` (presumably an income elasticity).
		pi: scalar broadcast to every firm (presumably the Calvo stickiness parameter).

	Returns:
		Tuple ``(calvo_wage, real_rigid_wage, full_wage, calvo_price,
		real_rigid_price, full_price)``: sticky prices only, sticky prices plus
		real rigidities, and the latter plus the misallocation channel, first
		for wages and then for prices.
	"""
	# NOTE(review): sp.weighted_exp is presumably a weighted expectation of its
	# first argument under the weights in its second -- confirm in support_functions.
	wavg = sp.weighted_exp
	stick = pi * np.ones(shares.shape)
	flex = 1 - stick
	sigma = 1/(1 - (1/markups))
	cost_weights = shares/markups
	flex_sigma = flex * sigma
	sticky_pass_sigma = stick * passthroughs * sigma
	blended_sigma = (stick*passthroughs + flex) * sigma

	# E_\lambda [ d(log \mu_\theta) / dw ]
	rigidity_term = wavg(stick * (1 - passthroughs), shares) * wavg(flex_sigma, shares) / wavg(blended_sigma, shares)
	dmu_dw = -(rigidity_term + wavg(flex, shares))
	dmu_dw_flexonly = -wavg(flex, shares)

	# dlogA/dw
	dA_dw = (
		wavg(flex_sigma, shares) * wavg(sticky_pass_sigma, cost_weights)
		- wavg(flex_sigma, cost_weights) * wavg(sticky_pass_sigma, shares)
	) / wavg(blended_sigma, shares)

	# All channels below are (dlog Y / dlog w); they are inverted on return.
	scale = 1/(1+inc_elast)
	# Sticky prices only.
	calvo_wage = scale * (-frisch * dmu_dw_flexonly)
	calvo_price = calvo_wage / (1 + dmu_dw_flexonly)
	# Sticky prices with real rigidities.
	rigid_wage = scale * (-frisch * dmu_dw)
	rigid_price = rigid_wage / (1 + dmu_dw)
	# Adding the misallocation channel.
	misalloc = scale * dA_dw
	full_wage = rigid_wage + misalloc
	full_price = rigid_price + misalloc / (1 + dmu_dw)

	channels = (calvo_wage, rigid_wage, full_wage, calvo_price, rigid_price, full_price)
	return tuple(1/channel for channel in channels)

def generate_markups(rho, shares, mu_bar):
	"""Solve for a markup profile whose aggregate matches ``mu_bar``.

	The profile is obtained by integrating
	``dmu/dtheta = mu*(mu-1)*(1-rho)/rho * dlog(lambda)/dtheta`` over an
	interior grid of ``len(rho)`` points in (0, 1) with ``sp.rk4_integrate``.
	The starting value ``mu_0`` is chosen by ``scipy.optimize.minimize``
	(initial guess 1.001) so that ``1/sp.weighted_exp(1/mu, shares)`` is as
	close as possible to ``mu_bar`` in squared error.

	Note: the other functions in this file call ``sp.generate_markups`` rather
	than this local definition.

	Args:
		rho: pass-through values on the grid.
		shares: weights on the same grid (``len(shares) == len(rho)`` is
			required for the interpolation nodes to line up).
		mu_bar: target aggregate markup.

	Returns:
		The integrated markup profile returned by ``sp.rk4_integrate``.
	"""
	theta = np.linspace(0, 1, len(rho)+2)[1:-1]
	# Interpolants are anchored at theta = 0 (rho = 1 there; the first log-share
	# slope is repeated) so they can be evaluated below the first grid point.
	rho_fn = interp1d(np.append([0], theta),np.append([1], rho))
	dloglambda = np.diff(np.log(shares)) / np.diff(theta)
	dloglambda_fn = interp1d(np.append([0], theta), np.append([dloglambda[0], dloglambda[0]], dloglambda))

	def dmu(theta, mu):
		# Right-hand side of the markup ODE.
		return mu*(mu-1)*(1-rho_fn(theta))/rho_fn(theta) * dloglambda_fn(theta)

	def mu_from_mu_0(mu_0, mu_bar):
		# Squared gap between the target and the implied aggregate markup.
		print('Trying: ', mu_0)
		mu = sp.rk4_integrate(mu_0, theta, dmu)
		mu_bar_pred = 1/(sp.weighted_exp(1/mu, shares))
		print('Predicted mu_bar: ', mu_bar_pred)
		return (mu_bar - mu_bar_pred)**2

	# `args` must be a tuple; `(mu_bar)` without the comma is just `mu_bar`.
	res = minimize(mu_from_mu_0, 1.001, args=(mu_bar,), tol=1E-5)
	mu = sp.rk4_integrate(res.x, theta, dmu)
	return mu

def find_pareto_tail_to_match_baseline(keys=None, darw_dat=None, truncate_ratio=None):
	"""Search for (share_tail, pareto_tail) matching a target Gini and price slope.

	Uses ``scipy.optimize.root`` starting from ``[0.002, 2]`` with targets
	Gini = 0.88 and price Phillips slope = 2.02, then prints the solver result
	and its solution vector.  Nothing is returned.

	Args:
		keys, darw_dat: forwarded to ``generate_mixed_distribution``.
		truncate_ratio: forwarded to ``generate_mixed_distribution``.  The
			original code read this name without ever defining it, which
			raised NameError on the first residual evaluation.
	"""
	def match_baseline(params, gini_target, phil_price_target):
		share_tail,pareto_tail = params
		# Push the solver away from negative parameters with a large residual.
		if share_tail < 0 or pareto_tail < 0:
			return [50, 50]
		shares,markups,passthroughs = generate_mixed_distribution(share_tail, pareto_tail=pareto_tail, keys=keys,darw_dat=darw_dat,truncate_ratio=truncate_ratio)
		gini_pred = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		_,_,_,_,_,phil_price = philips_slopes_kimball(shares, markups, passthroughs, frisch=0.2, inc_elast=0.2, pi=0.5)
		print(share_tail, pareto_tail, gini_pred, phil_price)
		return [gini_pred - gini_target, phil_price - phil_price_target]
	res = root(match_baseline, [0.002, 2], args=(0.88, 2.02), tol=1E-4)
	print(res)
	print(res.x)

def generate_mixed_distribution(share_tail, pareto_tail=2, keys=None, darw_dat=None, truncate_ratio=None):
	"""Build a firm distribution mixing a mass point with a power-law tail.

	A fraction ``share_tail`` of the grid points is drawn from the grid
	``(1 - x)**(-pareto_tail)``; the remaining points all sit at the value of
	``A`` found 60% of the way through the baseline grid.  Markups and
	pass-throughs are read off the baseline by interpolating in ``y``.

	Args:
		share_tail: fraction of grid points placed in the tail.
		pareto_tail: exponent of the tail grid.
		keys, darw_dat: baseline data; loaded with ``sp.read_darwinian_data()``
			when ``darw_dat`` is None.
		truncate_ratio: accepted for call compatibility but ignored; the cutoff
			is always derived from ``A[-1]/A[0]``.

	Returns:
		Tuple ``(weights, mu_dist, rho_dist)`` where ``weights`` is
		``mu_dist*n_dist`` normalised to sum to one.  Callers in this file
		unpack it as (shares, markups, passthroughs).
	"""
	if darw_dat is None:
		keys,darw_dat = sp.read_darwinian_data()
	rho = darw_dat[keys.index('rho')]
	shares = darw_dat[keys.index('lambda')]
	mu = sp.generate_markups(rho, shares, 1.15)
	# A starts at 1 and is accumulated in logs from changes in log shares.
	# NOTE(review): presumably A is firm productivity and y output -- confirm.
	A = np.ones(len(mu))
	for i in range(1,len(A)):
		A[i] = np.exp(np.log(A[i-1]) + (mu[i-1]-1)/rho[i-1] * (np.log(shares[i]) - np.log(shares[i-1])))
	shares = shares / np.sum(shares)
	y = A * shares / mu
	num_pareto = int(share_tail * len(A))
	num_degenerate = len(A) - num_pareto
	# At x = tail_cutoff the tail grid equals A[-1]/A[0].
	tail_cutoff = 1 - (A[-1]/A[0])**(-1/pareto_tail)
	# The grid is padded with `extra` points that are then dropped, so the
	# kept tail stops short of the cutoff.
	extra = 450
	pareto_tail_dist = ((1 - np.linspace(0,tail_cutoff,num_pareto+extra))**(-pareto_tail))[:-extra]
	A_dist = np.append([A[int(0.6*len(A))]] * num_degenerate, pareto_tail_dist)
	# Cap at the largest baseline value.
	A_dist[A_dist > np.max(A)] = np.max(A)
	y_from_Ainv = interp1d(1/A, y, fill_value='extrapolate')
	y_dist = y_from_Ainv(1/A_dist)
	n_dist = y_dist / A_dist
	# Baseline markup / pass-through as functions of y, anchored at (0, 1).
	mu_from_y = interp1d(np.append([0], y), np.append([1], mu))
	rho_from_y = interp1d(np.append([0], y), np.append([1], rho))
	mu_dist = mu_from_y(y_dist)
	rho_dist = rho_from_y(y_dist)
	return (mu_dist*n_dist)/np.sum(mu_dist*n_dist),mu_dist,rho_dist

def generate_truncated_distribution(share_truncate, keys=None, darw_dat=None, truncate_ratio=None):
	"""Build a firm distribution that drops the last grid points of the baseline.

	Only the first ``int((1-share_truncate)*len(A))`` baseline grid points are
	kept; markups and pass-throughs are read off the baseline by interpolating
	in ``y``.

	Args:
		share_truncate: fraction of grid points removed from the end of the grid.
		keys, darw_dat: baseline data; loaded with ``sp.read_darwinian_data()``
			when ``darw_dat`` is None.
		truncate_ratio: accepted for call compatibility but not used.

	Returns:
		Tuple ``(weights, mu_dist, rho_dist)`` where ``weights`` is
		``mu_dist*n_dist`` normalised to sum to one.  Callers in this file
		unpack it as (shares, markups, passthroughs).
	"""
	if darw_dat is None:
		keys,darw_dat = sp.read_darwinian_data()
	rho = darw_dat[keys.index('rho')]
	shares = darw_dat[keys.index('lambda')]
	mu = sp.generate_markups(rho, shares, 1.15)
	# A starts at 1 and is accumulated in logs from changes in log shares.
	# NOTE(review): presumably A is firm productivity and y output -- confirm.
	A = np.ones(len(mu))
	for i in range(1,len(A)):
		A[i] = np.exp(np.log(A[i-1]) + (mu[i-1]-1)/rho[i-1] * (np.log(shares[i]) - np.log(shares[i-1])))
	shares = shares / np.sum(shares)
	y = A * shares / mu
	num_keep = int((1-share_truncate)*len(A))
	y_dist = y[:num_keep]
	A_dist = A[:num_keep]
	n_dist = y_dist / A_dist
	# Baseline markup / pass-through as functions of y, anchored at (0, 1).
	mu_from_y = interp1d(np.append([0], y), np.append([1], mu))
	rho_from_y = interp1d(np.append([0], y), np.append([1], rho))
	mu_dist = mu_from_y(y_dist)
	rho_dist = rho_from_y(y_dist)
	return (mu_dist*n_dist)/np.sum(mu_dist*n_dist),mu_dist,rho_dist

def minimize_distance(params, A, shares):
	"""Weighted squared distance between a scaled Beta quantile grid and ``A``.

	The candidate grid is ``1 + (A[-1]/A[0] - 1) * Beta(a, b).ppf(q)`` for
	``len(A)`` evenly spaced ``q`` in [0, 1].  The objective and its parameters
	are printed on every call.

	Args:
		params: pair ``(a, b)`` of Beta shape parameters.
		A: target grid, same length as ``shares``.
		shares: weights applied to the squared errors.

	Returns:
		``sum((A_dist - A)**2 * shares)``, or ``1E10`` when either shape
		parameter is negative.
	"""
	a,b = params
	# Large penalty keeps the optimiser out of the invalid region.
	if min(a, b) < 0:
		return 1E10
	A_dist = 1 + (A[-1]/A[0] - 1) * beta.ppf(np.linspace(0,1,len(A)), a, b)
	# Sales-weighted sum of squared errors.  (An earlier quantile-matching
	# accumulation was overwritten by this value and never affected the result.)
	err = sum((A_dist - A)**2 * shares)
	print(params, err)
	return err

def root_distance_estimates(params, keys, darw_dat, gini_target, phil_price_target):
	"""Residuals (Gini gap, price-slope gap) for a Beta(a, b) firm distribution.

	Intended as the objective for a root finder: both entries are zero when the
	distribution built by ``generate_beta_distribution`` reproduces the target
	Gini and the target price Phillips slope.  The residual pair is printed on
	every evaluation.

	Args:
		params: pair ``(a, b)`` of Beta shape parameters.
		keys, darw_dat: baseline data forwarded to ``generate_beta_distribution``.
		gini_target: Gini coefficient to match.
		phil_price_target: price Phillips slope to match.

	Returns:
		``[gini - gini_target, slope - phil_price_target]``, or
		``[1E10, 1E10]`` when either shape parameter is negative.
	"""
	a, b = params
	# Negative shape parameters are invalid: report a huge residual instead.
	if min(a, b) < 0:
		return [1E10, 1E10]
	shares, markups, passthroughs = generate_beta_distribution(a, b, keys=keys, darw_dat=darw_dat)
	lorenz = np.cumsum(shares)/sum(shares)
	gini = 1 - 2*sum(lorenz)/len(shares)
	slopes = philips_slopes_kimball(shares, markups, passthroughs, frisch=0.2, inc_elast=0.2, pi=0.5)
	phil_price_slope = slopes[5]
	print([gini-gini_target, phil_price_slope-phil_price_target])
	return [gini-gini_target, phil_price_slope-phil_price_target]

# res = minimize(minimize_distance, [5,5], args=(A, shares))
# res = root(root_distance_estimates, [0.003,1.7], args=(keys,darw_dat,0.88,2.02))
# Match squared error of entire distribution
# a = 0.00255672
# b = 1.69893713
# Match Gini and slope of Belgian distribution
# a = 0.139316
# b = 15.72584302
def generate_beta_distribution(a=0.00255672, b=1.69893713, keys=None, darw_dat=None, truncate_ratio=None):
	"""Build a firm distribution whose ``A`` grid follows a scaled Beta(a, b).

	The grid is ``1 + (A[-1]/A[0] - 1) * Beta(a, b).ppf(q)`` for evenly spaced
	``q`` in [0, 1], so it spans the same range as the baseline ``A``.  Markups
	and pass-throughs are read off the baseline by interpolating in ``y``.

	Args:
		a, b: Beta shape parameters.
		keys, darw_dat: baseline data; loaded with ``sp.read_darwinian_data()``
			when ``darw_dat`` is None.
		truncate_ratio: accepted for call compatibility but not used.

	Returns:
		Tuple ``(weights, mu_dist, rho_dist)`` where ``weights`` is
		``mu_dist*n_dist`` normalised to sum to one.  Callers in this file
		unpack it as (shares, markups, passthroughs).
	"""
	if darw_dat is None:
		keys,darw_dat = sp.read_darwinian_data()
	rho = darw_dat[keys.index('rho')]
	shares = darw_dat[keys.index('lambda')]
	mu = sp.generate_markups(rho, shares, 1.15)
	# A starts at 1 and is accumulated in logs from changes in log shares.
	# NOTE(review): presumably A is firm productivity and y output -- confirm.
	A = np.ones(len(mu))
	for i in range(1,len(A)):
		A[i] = np.exp(np.log(A[i-1]) + (mu[i-1]-1)/rho[i-1] * (np.log(shares[i]) - np.log(shares[i-1])))
	shares = shares / np.sum(shares)
	y = A * shares / mu
	A_dist = 1 + (A[-1]/A[0] - 1) * beta.ppf(np.linspace(0,1,len(A)), a, b)
	y_from_Ainv = interp1d(1/A, y, fill_value='extrapolate')
	y_dist = y_from_Ainv(1/A_dist)
	n_dist = y_dist / A_dist
	# Baseline markup / pass-through as functions of y, anchored at (0, 1).
	mu_from_y = interp1d(np.append([0], y), np.append([1], mu))
	rho_from_y = interp1d(np.append([0], y), np.append([1], rho))
	mu_dist = mu_from_y(y_dist)
	rho_dist = rho_from_y(y_dist)
	return (mu_dist*n_dist)/np.sum(mu_dist*n_dist),mu_dist,rho_dist

def find_mixshare_for_gini(gini_target, keys, darw_dat, frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, truncate_ratio=7000):
	"""Find the tail share of the mixed distribution that matches a Gini target.

	Minimises the squared Gini gap over the ``share_tail`` argument of
	``generate_mixed_distribution`` (initial guess 0.002), then prints how much
	each channel flattens the wage and price Phillips slopes at the matched
	value.  Nothing is returned.

	Args:
		gini_target: Gini coefficient to match.
		keys, darw_dat: baseline data forwarded to ``generate_mixed_distribution``.
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		truncate_ratio: forwarded to ``generate_mixed_distribution``.
	"""
	def gini_for_pareto_parameter(pareto_shape):
		print(pareto_shape)
		# minimize() passes a length-1 array; converting it to a float avoids the
		# size-1-array-to-int conversion that NumPy >= 1.25 deprecates downstream.
		share_tail = float(np.ravel(pareto_shape)[0])
		shares,markups,passthroughs = generate_mixed_distribution(share_tail,keys=keys,darw_dat=darw_dat,truncate_ratio=truncate_ratio)
		gini_pred = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		print(gini_pred, shares)
		return (gini_pred - gini_target)**2
	res = minimize(gini_for_pareto_parameter, 0.002, tol=1E-4)
	matched_share = float(np.ravel(res.x)[0])
	shares,markups,passthroughs = generate_mixed_distribution(matched_share,keys=keys,darw_dat=darw_dat,truncate_ratio=truncate_ratio)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes,calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
	# Ratios of successive slopes: how much each added channel flattens the curve.
	real_rigid_contrib_wage = calvo_wage_slopes / real_rigid_wage_slopes
	misalloc_contrib_wage = real_rigid_wage_slopes / phil_wage_slopes
	real_rigid_contrib_price = calvo_price_slopes / real_rigid_price_slopes
	misalloc_contrib_price = real_rigid_price_slopes / phil_price_slopes
	overall_flattening_wage = calvo_wage_slopes / phil_wage_slopes
	overall_flattening_price = calvo_price_slopes / phil_price_slopes
	print('------- MATCHING GINI: ', gini_target, ' -------------')
	print('Matched Pareto parameter: ', res.x)
	print('Real rigid, wage: ', real_rigid_contrib_wage)
	print('Misallocation, wage: ', misalloc_contrib_wage)
	print('Overall, wage: ', overall_flattening_wage)
	print('Real rigid, price: ', real_rigid_contrib_price)
	print('Misallocation, price: ', misalloc_contrib_price)
	print('Overall, price: ', overall_flattening_price)



def gen_normalized_pareto_dist(pareto_shape, A):
	"""Return a Pareto-shaped grid rescaled to span ``[1, A[-1]/A[0]]``.

	The raw grid is ``(1 - q)**(-1/pareto_shape) - 1`` at ``len(A)`` evenly
	spaced ``q`` in [0, 1) (the endpoint 1 is excluded).  It is then scaled so
	its maximum equals ``A[-1]/A[0] - 1`` and shifted up by one.

	Args:
		pareto_shape: Pareto shape parameter.
		A: sequence whose length sets the grid size and whose end-to-start
			ratio sets the upper bound.

	Returns:
		numpy array of length ``len(A)`` starting at 1.
	"""
	quantiles = np.linspace(0, 1, num=len(A)+1)[:-1]
	raw = (1 - quantiles)**(-1/pareto_shape) - 1
	span = A[-1]/A[0] - 1
	return raw * (span / max(raw)) + 1

def find_pareto_parameter_for_gini(gini_target, keys, darw_dat, frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, truncate_ratio=7000):
	"""Find the Pareto shape parameter whose distribution matches a Gini target.

	The Gini is computed on ``shares/markups`` (normalised), and the squared gap
	to ``gini_target`` is minimised starting from a shape of 1.  The function
	then prints how much each channel flattens the wage and price Phillips
	slopes at the matched parameter.  Nothing is returned.

	NOTE(review): ``generate_distributions_from_pareto_prod2`` is not defined in
	the visible part of this file -- confirm it exists before calling.

	Args:
		gini_target: Gini coefficient to match.
		keys, darw_dat: baseline data forwarded to the distribution generator.
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		truncate_ratio: forwarded to the distribution generator.
	"""
	def build(shape):
		# Single place that calls the distribution generator.
		return generate_distributions_from_pareto_prod2(shape,keys=keys,darw_dat=darw_dat,truncate_ratio=truncate_ratio)

	def gini_for_pareto_parameter(pareto_shape):
		shares, markups, passthroughs = build(pareto_shape)
		emp_shares = shares/markups / (np.sum(shares/markups))
		lorenz = np.cumsum(emp_shares)/sum(emp_shares)
		gini_pred = 1 - 2*sum(lorenz)/len(emp_shares)
		return (gini_pred - gini_target)**2

	res = minimize(gini_for_pareto_parameter, 1, tol=1E-4)
	shares, markups, passthroughs = build(res.x)
	slopes = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
	calvo_wage, rigid_wage, full_wage, calvo_price, rigid_price, full_price = slopes
	# Ratios of successive slopes: how much each added channel flattens the curve.
	report = (
		('Real rigid, wage: ', calvo_wage / rigid_wage),
		('Misallocation, wage: ', rigid_wage / full_wage),
		('Overall, wage: ', calvo_wage / full_wage),
		('Real rigid, price: ', calvo_price / rigid_price),
		('Misallocation, price: ', rigid_price / full_price),
		('Overall, price: ', calvo_price / full_price),
	)
	print('------- MATCHING GINI: ', gini_target, ' -------------')
	print('Matched Pareto parameter: ', res.x)
	for label, ratio in report:
		print(label, ratio)

def vary_pareto_parameter(frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, truncate_ratio=7000, savefile=None):
	"""Sweep a Pareto shape parameter and plot summary statistics and slopes.

	For each of 50 shape values from 5 down to 1.2 the firm distribution is
	rebuilt; HHI, ``mu_bar = 1/sp.weighted_exp(1/markups, shares)``, the
	weighted pass-through and the Gini are recorded, and the six slopes from
	``philips_slopes_kimball`` are stored and handed to ``sp.plot_slopes``
	together with the Gini series.  A second pass evaluates the slopes for ten
	identical firms sharing the average pass-through; those values are only
	printed.

	NOTE(review): ``generate_distributions_from_pareto_prod2`` is not defined in
	the visible part of this file -- confirm it exists before calling.

	Args:
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		truncate_ratio: forwarded to the distribution generator.
		savefile: forwarded to ``sp.plot_slopes``.
	"""
	keys,darw_dat = sp.read_darwinian_data()
	pareto_shape_space = np.linspace(5, 1.2, 50)
	num_points = len(pareto_shape_space)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	hhi,mu_bar,E_rho,gini = (np.zeros(num_points) for _ in range(4))
	for i,pareto_shape in enumerate(pareto_shape_space):
		shares,markups,passthroughs = generate_distributions_from_pareto_prod2(pareto_shape,keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		# Scatter only at these exact shape values (float equality against the grid).
		if pareto_shape in [1.2, 1.0, 0.8, 0.5]:
			plt.scatter(np.log(shares), markups, label=str(pareto_shape))
		hhi[i] = sum((shares/sum(shares) * 100)**2)
		mu_bar[i] = 1/sp.weighted_exp(1/markups, shares)
		E_rho[i] = sp.weighted_exp(passthroughs, shares)
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(pareto_shape,gini[i],phil_price_slopes[i])
	plt.legend()
	plt.xlabel('$\\log(\\lambda_\\theta)$')
	plt.ylabel('$\\mu_\\theta$')
	# One figure per summary statistic, with the x-axis reversed.
	for stat,lab in zip([hhi, mu_bar, E_rho, gini], ['HHI', '$\\bar\\mu$', '$E_\\lambda[\\rho]$', 'Gini']):
		plt.figure()
		plt.plot(pareto_shape_space, stat)
		plt.title(lab)
		plt.xlabel('$\\xi$ (Pareto shape parameter)')
		plt.gca().set_xlim(plt.gca().get_xlim()[::-1])
	sp.plot_slopes(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, gini, 'Gini coefficient (firm employment)', savefile=savefile, reverse=False, overall_flattening=True)
	### Homogeneous firms version
	# NOTE: the slopes computed below are printed but neither plotted nor returned.
	rho = darw_dat[keys.index('rho')]
	shares = darw_dat[keys.index('lambda')]
	mu = sp.generate_markups(rho, shares, 1.15)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	for i,pareto_shape in enumerate(pareto_shape_space):
		shares,markups,passthroughs = generate_distributions_from_pareto_prod2(pareto_shape,keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		mean_rho = sp.weighted_exp(passthroughs, shares)
		# Locate mean_rho in the baseline pass-through grid and linearly
		# interpolate the baseline markup there.
		# NOTE(review): searching rho[::-1] assumes rho is decreasing -- confirm.
		rho_ind = len(rho) - np.searchsorted(rho[::-1],mean_rho,side='left')
		weight_ind = (mean_rho - rho[rho_ind-1])/(rho[rho_ind] - rho[rho_ind-1])
		E_mu = weight_ind*mu[rho_ind] + (1-weight_ind)*mu[rho_ind-1]
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(np.ones(10), np.ones(10)*E_mu, np.ones(10)*mean_rho, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(pareto_shape,gini[i],phil_price_slopes[i])

def vary_lognormal_parameter(frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, truncate_ratio=7000, savefile=None):
	"""Sweep the lognormal-distribution parameter and plot statistics and slopes.

	For each of 100 parameter values between 0.001 and 0.5 the firm
	distribution is rebuilt; HHI, ``mu_bar = 1/sp.weighted_exp(1/markups,
	shares)``, the weighted pass-through and the Gini are recorded, and the six
	slopes from ``philips_slopes_kimball`` are stored and handed to
	``sp.plot_slopes`` together with the Gini series.  A second pass evaluates
	the slopes for ten identical firms sharing the average pass-through; those
	values are only printed.

	NOTE(review): ``generate_distributions_from_lognormal_prod`` is not defined
	in the visible part of this file -- confirm it exists before calling.

	Args:
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		truncate_ratio: forwarded to the distribution generator.
		savefile: forwarded to ``sp.plot_slopes``.
	"""
	keys,darw_dat = sp.read_darwinian_data()
	NUM = 100
	pareto_shape_space = np.linspace(0.001, 0.5, NUM)
	# Four grid values (first, one third, two thirds, last) get a scatter plot.
	to_plot = [pareto_shape_space[0], pareto_shape_space[int(NUM/3)], pareto_shape_space[int(2*NUM/3)], pareto_shape_space[NUM-1]]
	num_points = len(pareto_shape_space)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	hhi,mu_bar,E_rho,gini = (np.zeros(num_points) for _ in range(4))
	for i,pareto_shape in enumerate(pareto_shape_space):
		shares,markups,passthroughs = generate_distributions_from_lognormal_prod(pareto_shape,keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		if pareto_shape in to_plot:
			plt.scatter(np.log(shares), markups, label=str(pareto_shape))
		hhi[i] = sum((shares/sum(shares) * 100)**2)
		mu_bar[i] = 1/sp.weighted_exp(1/markups, shares)
		E_rho[i] = sp.weighted_exp(passthroughs, shares)
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(pareto_shape,gini[i],phil_price_slopes[i])
	plt.legend()
	plt.xlabel('$\\log(\\lambda_\\theta)$')
	plt.ylabel('$\\mu_\\theta$')
	# One figure per summary statistic, with the x-axis reversed.
	for stat,lab in zip([hhi, mu_bar, E_rho, gini], ['HHI', '$\\bar\\mu$', '$E_\\lambda[\\rho]$', 'Gini']):
		plt.figure()
		plt.plot(pareto_shape_space, stat)
		plt.title(lab)
		plt.xlabel('$\\xi$ (Pareto shape parameter)')
		plt.gca().set_xlim(plt.gca().get_xlim()[::-1])
	sp.plot_slopes(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, gini, 'Gini coefficient (firm employment)', savefile=savefile, reverse=False, overall_flattening=True)
	### Homogeneous firms version
	# NOTE: the slopes computed below are printed but neither plotted nor returned.
	rho = darw_dat[keys.index('rho')]
	shares = darw_dat[keys.index('lambda')]
	mu = sp.generate_markups(rho, shares, 1.15)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	for i,pareto_shape in enumerate(pareto_shape_space):
		shares,markups,passthroughs = generate_distributions_from_lognormal_prod(pareto_shape,keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		mean_rho = sp.weighted_exp(passthroughs, shares)
		# Locate mean_rho in the baseline pass-through grid and linearly
		# interpolate the baseline markup there.
		# NOTE(review): searching rho[::-1] assumes rho is decreasing -- confirm.
		rho_ind = len(rho) - np.searchsorted(rho[::-1],mean_rho,side='left')
		weight_ind = (mean_rho - rho[rho_ind-1])/(rho[rho_ind] - rho[rho_ind-1])
		E_mu = weight_ind*mu[rho_ind] + (1-weight_ind)*mu[rho_ind-1]
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(np.ones(10), np.ones(10)*E_mu, np.ones(10)*mean_rho, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(pareto_shape,gini[i],phil_price_slopes[i])

def vary_mixeddist(frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, truncate_ratio=7000, savefile=None):
	"""Sweep the truncation share and plot summary statistics and slopes.

	For each of 100 values between 0 and 0.02 the distribution is rebuilt with
	``generate_truncated_distribution``; HHI, ``mu_bar =
	1/sp.weighted_exp(1/markups, shares)``, the weighted pass-through and the
	Gini are recorded, and the six slopes from ``philips_slopes_kimball`` are
	stored and handed to ``sp.plot_slopes`` together with the Gini series.

	Args:
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		truncate_ratio: forwarded to ``generate_truncated_distribution``.
		savefile: forwarded to ``sp.plot_slopes``.
	"""
	keys,darw_dat = sp.read_darwinian_data()
	NUM = 100
	pareto_shape_space = np.linspace(0, 0.02, NUM)
	# Four grid values (first, one third, two thirds, last) get a scatter plot.
	to_plot = [pareto_shape_space[0], pareto_shape_space[int(NUM/3)], pareto_shape_space[int(2*NUM/3)], pareto_shape_space[NUM-1]]
	num_points = len(pareto_shape_space)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	hhi,mu_bar,E_rho,gini = (np.zeros(num_points) for _ in range(4))
	# The swept value is the `share_truncate` argument of the generator.
	for i,share_truncate in enumerate(pareto_shape_space):
		shares,markups,passthroughs = generate_truncated_distribution(share_truncate, keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		if share_truncate in to_plot:
			plt.scatter(np.log(shares), markups, label=str(share_truncate))
		hhi[i] = sum((shares/sum(shares) * 100)**2)
		mu_bar[i] = 1/sp.weighted_exp(1/markups, shares)
		E_rho[i] = sp.weighted_exp(passthroughs, shares)
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(share_truncate,gini[i],phil_price_slopes[i])
	plt.legend()
	plt.xlabel('$\\log(\\lambda_\\theta)$')
	plt.ylabel('$\\mu_\\theta$')
	# One figure per summary statistic, with the x-axis reversed.
	for stat,lab in zip([hhi, mu_bar, E_rho, gini], ['HHI', '$\\bar\\mu$', '$E_\\lambda[\\rho]$', 'Gini']):
		plt.figure()
		plt.plot(pareto_shape_space, stat)
		plt.title(lab)
		plt.xlabel('$\\xi$ (Pareto shape parameter)')
		plt.gca().set_xlim(plt.gca().get_xlim()[::-1])
	sp.plot_slopes(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, gini, 'Gini coefficient (firm employment)', savefile=savefile, reverse=False, overall_flattening=True)
	# NOTE: unlike vary_pareto_parameter, vary_lognormal_parameter and
	# vary_perturb, this function has no homogeneous-firm pass.

def vary_perturb(frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, truncate_ratio=7000, savefile=None):
	"""Sweep the ``perturb_mass`` parameter and plot statistics and slopes.

	For each of 100 parameter values between 0.05 and 5 the firm distribution
	is rebuilt; HHI, ``mu_bar = 1/sp.weighted_exp(1/markups, shares)``, the
	weighted pass-through and the Gini are recorded, and the six slopes from
	``philips_slopes_kimball`` are stored and handed to ``sp.plot_slopes``
	together with the Gini series.  A second pass evaluates the slopes for ten
	identical firms sharing the average pass-through; those values are only
	printed.

	NOTE(review): ``perturb_mass`` is not defined in the visible part of this
	file -- confirm it exists before calling.

	Args:
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		truncate_ratio: forwarded to ``perturb_mass``.
		savefile: forwarded to ``sp.plot_slopes``.
	"""
	keys,darw_dat = sp.read_darwinian_data()
	NUM = 100
	pareto_shape_space = np.linspace(0.05, 5, NUM)
	# Four grid values (first, one third, two thirds, last) get a scatter plot.
	to_plot = [pareto_shape_space[0], pareto_shape_space[int(NUM/3)], pareto_shape_space[int(2*NUM/3)], pareto_shape_space[NUM-1]]
	num_points = len(pareto_shape_space)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	hhi,mu_bar,E_rho,gini = (np.zeros(num_points) for _ in range(4))
	for i,pareto_shape in enumerate(pareto_shape_space):
		shares,markups,passthroughs = perturb_mass(pareto_shape,keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		if pareto_shape in to_plot:
			plt.scatter(np.log(shares), markups, label=str(pareto_shape))
		hhi[i] = sum((shares/sum(shares) * 100)**2)
		mu_bar[i] = 1/sp.weighted_exp(1/markups, shares)
		E_rho[i] = sp.weighted_exp(passthroughs, shares)
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(pareto_shape,gini[i],phil_price_slopes[i])
	plt.legend()
	plt.xlabel('$\\log(\\lambda_\\theta)$')
	plt.ylabel('$\\mu_\\theta$')
	# One figure per summary statistic, with the x-axis reversed.
	for stat,lab in zip([hhi, mu_bar, E_rho, gini], ['HHI', '$\\bar\\mu$', '$E_\\lambda[\\rho]$', 'Gini']):
		plt.figure()
		plt.plot(pareto_shape_space, stat)
		plt.title(lab)
		plt.xlabel('$\\xi$ (Pareto shape parameter)')
		plt.gca().set_xlim(plt.gca().get_xlim()[::-1])
	sp.plot_slopes(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, gini, 'Gini coefficient (firm employment)', savefile=savefile, reverse=False, overall_flattening=True)
	### Homogeneous firms version
	# NOTE: the slopes computed below are printed but neither plotted nor returned.
	rho = darw_dat[keys.index('rho')]
	shares = darw_dat[keys.index('lambda')]
	mu = sp.generate_markups(rho, shares, 1.15)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	for i,pareto_shape in enumerate(pareto_shape_space):
		shares,markups,passthroughs = perturb_mass(pareto_shape,keys=keys,darw_dat=darw_dat, truncate_ratio=truncate_ratio)
		mean_rho = sp.weighted_exp(passthroughs, shares)
		# Locate mean_rho in the baseline pass-through grid and linearly
		# interpolate the baseline markup there.
		# NOTE(review): searching rho[::-1] assumes rho is decreasing -- confirm.
		rho_ind = len(rho) - np.searchsorted(rho[::-1],mean_rho,side='left')
		weight_ind = (mean_rho - rho[rho_ind-1])/(rho[rho_ind] - rho[rho_ind-1])
		E_mu = weight_ind*mu[rho_ind] + (1-weight_ind)*mu[rho_ind-1]
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(np.ones(10), np.ones(10)*E_mu, np.ones(10)*mean_rho, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(pareto_shape,gini[i],phil_price_slopes[i])

def vary_beta(frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, savefile=None):
	"""Scale the Beta(a, b) parameters jointly and plot statistics and slopes.

	Both shape parameters start from ``a = 0.139316`` and ``b = 15.72584302``
	and are multiplied by a common factor taking 200 values between 0.8 and 15.
	For each factor the distribution is rebuilt with
	``generate_beta_distribution``; HHI, ``mu_bar =
	1/sp.weighted_exp(1/markups, shares)``, the weighted pass-through and the
	Gini are recorded, and the six slopes from ``philips_slopes_kimball`` are
	stored and handed to ``sp.plot_slopes`` together with the Gini series.

	Args:
		frisch_fix, inc_elast_fix, pi_fix: forwarded to
			``philips_slopes_kimball`` as ``frisch``, ``inc_elast`` and ``pi``.
		savefile: forwarded to ``sp.plot_slopes``.
	"""
	keys,darw_dat = sp.read_darwinian_data()
	a_start = 0.139316
	b_start = 15.72584302
	NUM = 200
	pareto_shape_space = np.linspace(0.8, 15, NUM)
	# Four grid values (first, one third, two thirds, last) get a scatter plot.
	to_plot = [pareto_shape_space[0], pareto_shape_space[int(NUM/3)], pareto_shape_space[int(2*NUM/3)], pareto_shape_space[NUM-1]]
	num_points = len(pareto_shape_space)
	calvo_wage_slopes,real_rigid_wage_slopes,phil_wage_slopes = (np.zeros(num_points) for _ in range(3))
	calvo_price_slopes,real_rigid_price_slopes,phil_price_slopes = (np.zeros(num_points) for _ in range(3))
	hhi,mu_bar,E_rho,gini = (np.zeros(num_points) for _ in range(4))
	# The swept value is a common multiplier applied to a_start and b_start.
	for i,scale in enumerate(pareto_shape_space):
		shares,markups,passthroughs = generate_beta_distribution(a=scale*a_start, b=scale*b_start, keys=keys,darw_dat=darw_dat)
		if scale in to_plot:
			plt.scatter(np.log(shares), markups, label=str(scale))
		hhi[i] = sum((shares/sum(shares) * 100)**2)
		mu_bar[i] = 1/sp.weighted_exp(1/markups, shares)
		E_rho[i] = sp.weighted_exp(passthroughs, shares)
		gini[i] = 1 - 2*sum(np.cumsum(shares)/sum(shares))/len(shares)
		calvo_wage_slopes[i],real_rigid_wage_slopes[i],phil_wage_slopes[i],calvo_price_slopes[i],real_rigid_price_slopes[i],phil_price_slopes[i] = philips_slopes_kimball(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		print(scale,gini[i],phil_price_slopes[i])
	plt.legend()
	plt.xlabel('$\\log(\\lambda_\\theta)$')
	plt.ylabel('$\\mu_\\theta$')
	# One figure per summary statistic, with the x-axis reversed.
	for stat,lab in zip([hhi, mu_bar, E_rho, gini], ['HHI', '$\\bar\\mu$', '$E_\\lambda[\\rho]$', 'Gini']):
		plt.figure()
		plt.plot(pareto_shape_space, stat)
		plt.title(lab)
		plt.xlabel('$\\xi$ (Pareto shape parameter)')
		plt.gca().set_xlim(plt.gca().get_xlim()[::-1])
	sp.plot_slopes(calvo_wage_slopes, real_rigid_wage_slopes, phil_wage_slopes, calvo_price_slopes, real_rigid_price_slopes, phil_price_slopes, gini, 'Gini coefficient (firm employment)', savefile=savefile, reverse=False, overall_flattening=True)

# Script entry point: run the Beta-parameter sweep and display every figure it
# created.  These statements run at import time as well as on direct execution.
vary_beta(savefile=None)

plt.show()

